#include "../thirdparty/nccl-sg/src/include/nccl_net.h"
#include "transport.h"
#include "transport_config.h"
#include <glog/logging.h>
#include <atomic>
#include <fstream>
#include <iostream>
#include <mutex>
#include <new>
#include <string>
#include <thread>
#include <unistd.h>

using namespace uccl;

// Plugin name reported to NCCL through the exported ncclNet_v8 vtable.
char const* PLUGIN_NAME = "EFA_Plugin";

// Fixed-size pool of UcclRequest objects, one per comm (send or recv).
// Requests are handed out in pluginIsend/Irecv/Iflush and returned in
// pluginTest/pluginIrecvFreePtrs. Capacity is kMaxUnconsumedRxMsgbufs.
class UcclRequestBuffPool : public BuffPool {
  static constexpr size_t num_elements =
      kMaxUnconsumedRxMsgbufs;  // Send and receive.
  static constexpr size_t element_size = sizeof(UcclRequest);

 public:
  UcclRequestBuffPool() : BuffPool(num_elements, element_size, nullptr) {}
  ~UcclRequestBuffPool() = default;
};

// Global UCCL endpoint: created in pluginInit(), destroyed in pluginDeregMr()
// once the last memory registration is released.
Endpoint* ep;

// Progress states for the non-blocking connect/accept handshake.
enum ConnState { kConnInit = 0, kConnConnecting, kConnConnected };

// State common to send and recv comms: owning virtual device, the UCCL
// connection id, and the per-comm request pool.
struct UcclBaseComm {
  int vdev;
  ConnID conn_id;
  std::shared_ptr<UcclRequestBuffPool> uccl_req_pool;
};

// Result buffer filled by the background accept thread spawned in
// pluginAccept(); consumed once the listen comm reaches kConnConnected.
struct AsyncAcceptState {
  struct UcclBaseComm base;
  std::string remote_ip_str;
  int remote_vdev;
};

// Result buffer filled by the background connect thread spawned in
// pluginConnect(); consumed once the handle reaches kConnConnected.
struct AsyncConnectState {
  struct UcclBaseComm base;
};

// Handle generated by pluginListen() and then transferred (as raw bytes via
// NCCL bootstrap) to the remote side, which calls pluginConnect() with it.
// Only ip_addr_u32/remote_vdev/listen_port are meaningful on the wire; state,
// connect_buffer and fence are local scratch for the connecting side.
struct UcclHandle {
  uint32_t ip_addr_u32;  // Listener's IP as u32 (from str_to_ip()).
  int remote_vdev;       // Listener's virtual device index.
  int listen_port;       // Port returned by uccl_listen().
  std::atomic<enum ConnState> state = kConnInit;
  AsyncConnectState connect_buffer;  // Written by the async connect thread.
  std::atomic<bool> fence = false;   // Publish flag; see write/read_barrier.
};
static_assert(sizeof(struct UcclHandle) < NCCL_NET_HANDLE_MAXSIZE,
              "UcclHandle size too large");

// Handle generated by pluginListen() for pluginAccept() to use.
struct UcclListenComm {
  int vdev;
  int remote_dev;  // NOTE(review): never written in this file — confirm use.
  int listen_fd;
  std::atomic<enum ConnState> state = kConnInit;
  AsyncAcceptState accept_buffer;   // Written by the async accept thread.
  std::atomic<bool> fence = false;  // Publish flag; see write/read_barrier.
};

// Publish side of a release/acquire pairing: all writes made before this call
// become visible to a reader that observes `fence` via read_barrier().
static void write_barrier(std::atomic<bool>& fence) {
  std::atomic_thread_fence(std::memory_order_release);
  fence.store(true, std::memory_order_relaxed);
}
// Consume side of the pairing: load the flag, then raise an acquire fence so
// reads after this call are ordered after the writer's release in
// write_barrier().
static void read_barrier(std::atomic<bool>& fence) {
  (void)fence.load(std::memory_order_relaxed);
  std::atomic_thread_fence(std::memory_order_acquire);
}

// Handle generated by pluginAccept.
struct UcclRecvComm {
  struct UcclBaseComm base;
  std::string remote_ip_str;  // Peer IP, copied from the accept buffer.
  int remote_vdev;            // Peer virtual device index.
};

// Handle generated by pluginConnect.
struct UcclSendComm {
  struct UcclBaseComm base;
};

// One-time plugin initialization: set up glog and create the global UCCL
// endpoint for the GPU currently bound to this process.
// NOTE(review): the NCCL-supplied logFunction is ignored; all plugin logging
// goes through glog instead.
ncclResult_t pluginInit(ncclDebugLogger_t logFunction) {
  google::InitGoogleLogging("UCCL");
  google::InstallFailureSignalHandler();

  int gpu;
  cudaGetDevice(&gpu);

#ifdef LAZY_CREATE_ENGINE
  // Lazy mode: engines are bound later via initialize_engine_by_gpu_idx()
  // in pluginListen().
  ep = new Endpoint();
#else
  ep = new Endpoint(gpu);
#endif

  return ncclSuccess;
}

// Report how many (virtual) network devices this plugin exposes. On p4d each
// physical NIC is virtualized into two to ease the NIC-GPU mapping.
ncclResult_t pluginDevices(int* ndev) {
  *ndev = NUM_DEVICES;
  return ncclSuccess;
}

/// @ref ncclIbGetPciPath
/// Resolve the sysfs symlink for an IB device name into its canonical PCI
/// device path. On success, *path points to a heap buffer allocated by
/// realpath(); ownership passes to the caller.
ncclResult_t pluginPciPath(char const* ib_name, char** path) {
  char devicePath[256];
  snprintf(devicePath, sizeof(devicePath), "/sys/class/infiniband/%s/device",
           ib_name);
  char* p = realpath(devicePath, nullptr);
  if (p == nullptr) {
    LOG(ERROR) << "Could not find device path for " << ib_name;
    return ncclInternalError;
  }
  *path = p;  // malloc'ed by realpath(); caller owns it.
  return ncclSuccess;
}

// Report NIC properties for virtual device `pdev` to NCCL following the
// ncclNetProperties_v8_t contract: link speed, PCI path, pointer-type
// support, and tuning knobs.
ncclResult_t pluginGetProperties(int pdev, ncclNetProperties_v8_t* props) {
#ifdef LAZY_CREATE_ENGINE
  auto factory_dev = EFAFactory::GetEFADevice(pdev);
#else
  auto factory_dev = EFAFactory::GetEFADevice(ep->gpu_);
#endif
  props->name = factory_dev->ib_name;

  // Speed in *Mbps*. 100000 means 100G
  props->speed = kLinkBandwidth * 8 / 1e6;

  // NOTE(review): return value not checked; on failure props->pciPath is
  // left unset — confirm NCCL tolerates that.
  pluginPciPath(factory_dev->ib_name, &props->pciPath);

  // Only used to detect NICs with multiple PCI attachments.
  props->guid = factory_dev->dev_attr.sys_image_guid + pdev;
  LOG(INFO) << "pluginGetProperties dev " << pdev << " guid " << props->guid
            << " name " << props->name << " pciPath " << props->pciPath;

  props->ptrSupport = NCCL_PTR_HOST;
  if (factory_dev->dma_buf_support)
    props->ptrSupport |= NCCL_PTR_CUDA | NCCL_PTR_DMABUF;

  // If you regMr has a fast registration cache, set to 1. If set to 0, user
  // buffer registration may be disabled.
  props->regIsGlobal = 0;

  // Port number, used in conjunction with guid
  props->port = EFA_PORT_NUM;
  // Custom latency (used to help tuning if latency is high. If set to 0, use
  // default NCCL values.
  props->latency = 0;
  // Maximum number of comm objects we can create.
  props->maxComms = 1024 * 1024;
  // Maximum number of receive operations taken by irecv().
  // props->maxRecvs = kMaxMultiRecv;
  // Yang: to make alltoall nvlink on work!
  props->maxRecvs = 1;
  // Coupling with NCCL network device-side code.
  props->netDeviceType = NCCL_NET_DEVICE_HOST;
  props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION;
  return ncclSuccess;
}

// To create a connection, NCCL will start by calling listen on the receiver
// side. This function takes a device number as input argument, and should
// return a local listenComm object, and a handle to pass to the other side, so
// that the sender side can connect to the receiver. The handle is a buffer of
// size NCCL_NET_HANDLE_MAXSIZE and is provided by NCCL. This call should never
// block, but contrary to connect and accept, listenComm should never be NULL if
// the call succeeds.
ncclResult_t pluginListen(int vdev, void* opaque_handle, void** listenComm) {
  int gpu_idx = 0;
  cudaGetDevice(&gpu_idx);
  // NCCL's device index can differ from the GPU actually bound to this
  // process; trust the CUDA runtime so engine/NIC mapping follows the GPU.
  // (The original code performed this exact check twice in a row.)
  if (vdev != gpu_idx) {
    LOG_FIRST_N(INFO, 1) << "pluginListen detects different vdev " << vdev
                         << " vs. gpu_idx " << gpu_idx
                         << ", forcely setting vdev to gpu_idx";
    vdev = gpu_idx;
  }

  LOG(INFO) << "[pluginListen] pid=" << getpid() << ", using GPU " << gpu_idx
            << ", vdev=" << vdev;

  ep->initialize_engine_by_gpu_idx(gpu_idx);
  auto pdev = get_pdev(vdev);
  // Construct the handle inside the NCCL-provided buffer. Placement-new
  // (instead of memset) properly starts the lifetime of the non-trivial
  // members (atomics and the embedded AsyncConnectState).
  struct UcclHandle* handle = new (opaque_handle) UcclHandle();

  auto [listen_port, listen_fd] = ep->uccl_listen();

  // Fill out handle which will be passed to the other side.
#ifdef LAZY_CREATE_ENGINE
  auto factory_dev = EFAFactory::GetEFADevice(pdev);
#else
  auto factory_dev = EFAFactory::GetEFADevice(ep->gpu_);
#endif
  handle->ip_addr_u32 = str_to_ip(factory_dev->local_ip_str);
  handle->remote_vdev = vdev;
  handle->listen_port = listen_port;

  // Construct (not just calloc) the listen comm so its atomics and the
  // embedded AsyncAcceptState (std::string / shared_ptr) are valid objects.
  void* lmem = calloc(1, sizeof(struct UcclListenComm));
  struct UcclListenComm* lcomm = new (lmem) UcclListenComm();

  lcomm->vdev = vdev;
  lcomm->state = kConnInit;
  lcomm->listen_fd = listen_fd;
  *listenComm = lcomm;

  LOG(INFO) << "pluginListen on vdev: " << vdev << " listen_port "
            << listen_port << " listen_fd " << listen_fd << " gpu_idx "
            << gpu_idx;

  return ncclSuccess;
}

// NCCL will use its bootstrap infrastructure to provide the handle to the
// sender side, then call connect on the sender side on a given device index
// dev, providing the handle. connect should not block either, and instead set
// sendComm to NULL and return ncclSuccess. In that case, NCCL will call accept
// again until it succeeds.
ncclResult_t pluginConnect(int vdev, void* opaque_handle, void** sendComm,
                           ncclNetDeviceHandle_v8_t** /*sendDevComm*/) {
  int gpu_idx = 0;
  cudaGetDevice(&gpu_idx);
  if (vdev != gpu_idx) {
    LOG_FIRST_N(INFO, 1) << "pluginConnect detects different vdev " << vdev
                         << " vs. gpu_idx " << gpu_idx
                         << ", forcely setting vdev to gpu_idx";
    vdev = gpu_idx;
  }

  auto pdev = get_pdev(vdev);
  struct UcclHandle* handle = (struct UcclHandle*)opaque_handle;

  std::string remote_ip_str = ip_to_str(handle->ip_addr_u32);

  // Default: connection not ready yet; NCCL will call us again.
  *sendComm = nullptr;

  if (handle->state == kConnInit) {
    LOG(INFO) << "pluginConnect on vdev: " << vdev << " remote_ip_str "
              << remote_ip_str << " dest_port " << handle->listen_port
              << " gpu_idx " << gpu_idx;
    handle->state = kConnConnecting;
    // Delegate the blocking connect to a detached thread so this call stays
    // non-blocking as the NCCL plugin API requires.
    std::thread t = std::thread([vdev, handle, remote_ip_str] {
      handle->connect_buffer.base.conn_id = ep->uccl_connect(
          vdev, handle->remote_vdev, remote_ip_str, handle->listen_port);
      handle->connect_buffer.base.vdev = vdev;
      handle->state = kConnConnected;
      write_barrier(handle->fence);
    });
    t.detach();
  } else if (handle->state == kConnConnected) {
    // Pairs with write_barrier() in the connect thread.
    read_barrier(handle->fence);
    // Allocate the comm only once the connection is established (the
    // original calloc'ed and free'd on every polling call). Placement-new
    // starts the lifetime of the shared_ptr member before it is assigned.
    void* mem = calloc(1, sizeof(struct UcclSendComm));
    struct UcclSendComm* scomm = new (mem) UcclSendComm();
    scomm->base = handle->connect_buffer.base;
    scomm->base.uccl_req_pool = std::make_shared<UcclRequestBuffPool>();
    *sendComm = scomm;
    LOG(INFO) << Format("Connected to %s/%d on vdev %d with flow_id %lu\n",
                        remote_ip_str.c_str(), handle->remote_vdev, vdev,
                        scomm->base.conn_id.flow_id);
  }
  // kConnConnecting: connect thread still running; *sendComm stays nullptr.

  return ncclSuccess;
}

// To finalize the connection, the receiver side will call accept on the
// listenComm returned by the listen call previously. If the sender did not
// connect yet, accept should not block. It should return ncclSuccess, setting
// recvComm to NULL. NCCL will call accept again until it succeeds.
ncclResult_t pluginAccept(void* listenComm, void** recvComm,
                          ncclNetDeviceHandle_v8_t** /*recvDevComm*/) {
  int gpu_idx = 0;
  cudaGetDevice(&gpu_idx);
  struct UcclListenComm* lcomm = (struct UcclListenComm*)listenComm;

  // Default: connection not ready yet; NCCL will call us again.
  *recvComm = nullptr;

  if (lcomm->state == kConnInit) {
    DCHECK(lcomm->vdev == gpu_idx)
        << "pluginAccept: vdev " << lcomm->vdev << " vs. gpu_idx " << gpu_idx;
    auto vdev = lcomm->vdev;
    LOG(INFO) << "pluginAccept on vdev: " << vdev << " listen_fd "
              << lcomm->listen_fd << " gpu_idx " << gpu_idx;
    lcomm->state = kConnConnecting;
    // Delegate the blocking accept to a detached thread so this call stays
    // non-blocking as the NCCL plugin API requires.
    std::thread t = std::thread([lcomm, vdev] {
      std::string remote_ip_str;
      int remote_vdev;
      lcomm->accept_buffer.base.conn_id =
          ep->uccl_accept(vdev, &remote_vdev, remote_ip_str, lcomm->listen_fd);
      lcomm->accept_buffer.base.vdev = vdev;
      lcomm->accept_buffer.remote_ip_str = remote_ip_str;
      lcomm->accept_buffer.remote_vdev = remote_vdev;
      lcomm->state = kConnConnected;
      write_barrier(lcomm->fence);
    });
    t.detach();
  } else if (lcomm->state == kConnConnected) {
    // Pairs with write_barrier() in the accept thread.
    read_barrier(lcomm->fence);
    // Allocate the comm only once the connection is established (the
    // original calloc'ed and free'd on every polling call). Placement-new
    // starts the lifetime of the string/shared_ptr members before use.
    void* mem = calloc(1, sizeof(struct UcclRecvComm));
    struct UcclRecvComm* rcomm = new (mem) UcclRecvComm();
    rcomm->base = lcomm->accept_buffer.base;
    rcomm->base.uccl_req_pool = std::make_shared<UcclRequestBuffPool>();
    rcomm->remote_ip_str = lcomm->accept_buffer.remote_ip_str;
    rcomm->remote_vdev = lcomm->accept_buffer.remote_vdev;
    *recvComm = rcomm;
    LOG(INFO) << Format("Accepted from %s/%d on vdev %d with flow_id %lu\n",
                        rcomm->remote_ip_str.c_str(), rcomm->remote_vdev,
                        lcomm->vdev, rcomm->base.conn_id.flow_id);
  }
  // kConnConnecting: accept thread still running; *recvComm stays nullptr.

  return ncclSuccess;
}

// Count of live memory registrations; when it drops back to zero in
// pluginDeregMr() the global endpoint is torn down.
static std::atomic<uint32_t> reg_cnt = 0;

// Register [data, data+size) with the NIC backing this comm's engine.
// `type` is the NCCL pointer type (host/CUDA); the resulting registration is
// returned to NCCL via *mhandle.
ncclResult_t pluginRegMr(void* collComm, void* data, size_t size, int type,
                         void** mhandle) {
  int ret;
  struct UcclBaseComm* base = (struct UcclBaseComm*)collComm;
  auto dev_idx = get_dev_idx_by_engine_idx(base->conn_id.engine_idx);
  auto vdev_idx = base->vdev;
  checkMemoryLocation(data);

  LOG(INFO) << "pluginRegMr, size " << size << " flow_id "
            << base->conn_id.flow_id << " vdev_idx " << vdev_idx << " data ptr "
            << std::hex << data;
  ret = ep->uccl_regmr(dev_idx, data, size, type, (struct Mhandle**)mhandle);
  reg_cnt++;

  return ret == 0 ? ncclSuccess : ncclInternalError;
}

// DMA-BUF variant of pluginRegMr: register the memory described by the
// dma-buf file descriptor `fd` at `offset`, for buffers advertised with
// NCCL_PTR_DMABUF support in pluginGetProperties().
ncclResult_t pluginRegMrDmaBuf(void* collComm, void* data, size_t size,
                               int type, uint64_t offset, int fd,
                               void** mhandle) {
  int ret;
  struct UcclBaseComm* base = (struct UcclBaseComm*)collComm;
  auto dev_idx = get_dev_idx_by_engine_idx(base->conn_id.engine_idx);
  auto vdev_idx = base->vdev;
  checkMemoryLocation(data);

  LOG(INFO) << "pluginRegMrDmaBuf, size " << size << " flow_id "
            << base->conn_id.flow_id << " vdev_idx " << vdev_idx << " data ptr "
            << std::hex << data;
  ret = ep->uccl_regmr_dmabuf(dev_idx, data, size, type, offset, fd,
                              (struct Mhandle**)mhandle);
  reg_cnt++;

  return ret == 0 ? ncclSuccess : ncclInternalError;
}

// Deregister a memory region previously registered with pluginRegMr /
// pluginRegMrDmaBuf. When the last registration is released, the global
// endpoint is destroyed.
ncclResult_t pluginDeregMr(void* collComm, void* mhandle) {
  struct UcclBaseComm* base = (struct UcclBaseComm*)collComm;
  ep->uccl_deregmr((struct Mhandle*)mhandle);
  LOG(INFO) << "pluginDeregMr, " << base->conn_id.flow_id;
  if (--reg_cnt == 0) {
    delete ep;
    // Null the global so any late use fails fast instead of touching freed
    // memory (the original left `ep` dangling).
    ep = nullptr;
  }
  return ncclSuccess;
}

// Post a non-blocking send of `size` bytes on the send comm. The request is
// drawn from the comm's pool and completed via pluginTest().
ncclResult_t pluginIsend(void* sendComm, void* data, int size, int tag,
                         void* mhandle, void** request) {
  // DCHECK(size > 0 && size <= 524288) << "size " << size;
  // DCHECK(size <= 1048576) << "pluginIsend size " << size;
  struct UcclSendComm* scomm = (struct UcclSendComm*)sendComm;
  auto conn_id = scomm->base.conn_id;
  struct Mhandle* mh = (struct Mhandle*)mhandle;
  // checkMemoryLocation(data);
  uint64_t addr;
  auto vdev = scomm->base.vdev;
  // Pool exhaustion is treated as fatal (CHECK); the nullptr-request
  // fallback below is effectively dead code.
  if (scomm->base.uccl_req_pool->alloc_buff(&addr)) {
    CHECK(false);
    *request = nullptr;
    return ncclSuccess;
  }

  struct UcclRequest* req = reinterpret_cast<struct UcclRequest*>(addr);
  req->type = ReqTx;
  req->n = 1;
  req->send_len = size;
  req->poll_ctx = ep->uccl_send_async(conn_id, data, req->send_len, mh);
  // Remember the owning pool so pluginTest() can return the request to it.
  req->req_pool = (void*)scomm->base.uccl_req_pool.get();

  *request = req;
  // LOG(INFO) << "pluginIsend on size " << size;

  return ncclSuccess;
}

// Post a non-blocking multi-receive of `n` buffers on the recv comm. The
// request is drawn from the comm's pool and completed via pluginTest().
ncclResult_t pluginIrecv(void* recvComm, int n, void** data, int* sizes,
                         int* tags, void** mhandles, void** request) {
  struct UcclRecvComm* rcomm = (struct UcclRecvComm*)recvComm;
  auto conn_id = rcomm->base.conn_id;
  struct Mhandle** mhs = (struct Mhandle**)mhandles;
  // checkMemoryLocation(data[0]);

  uint64_t addr;
  auto vdev = rcomm->base.vdev;
  // Pool exhaustion is treated as fatal (CHECK); the nullptr-request
  // fallback below is effectively dead code.
  if (rcomm->base.uccl_req_pool->alloc_buff(&addr)) {
    CHECK(false);
    *request = nullptr;
    return ncclSuccess;
  }

  struct UcclRequest* req = reinterpret_cast<struct UcclRequest*>(addr);
  req->type = ReqRx;
  req->n = n;
  // Received lengths land in req->recv_len and are reported back to NCCL
  // from pluginTest().
  req->poll_ctx =
      ep->uccl_recv_multi_async(conn_id, data, req->recv_len, mhs, n);
  req->req_pool = (void*)rcomm->base.uccl_req_pool.get();

  *request = req;

  return ncclSuccess;
}

// Post a scattered receive: data lands in plugin-owned buffers tracked in the
// request (see iov_addrs/iov_n usage in pluginIrecvFreePtrs), so NCCL does
// not need to supply destination memory.
ncclResult_t pluginIrecvScattered(void* recvComm, int* tags, void* mhandles,
                                  void** request) {
  struct UcclRecvComm* rcomm = (struct UcclRecvComm*)recvComm;
  auto conn_id = rcomm->base.conn_id;
  struct Mhandle* mhs = (struct Mhandle*)mhandles;
  // checkMemoryLocation(data[0]);

  uint64_t addr;
  auto vdev = rcomm->base.vdev;
  // Pool exhaustion is treated as fatal (CHECK); the nullptr-request
  // fallback below is effectively dead code.
  if (rcomm->base.uccl_req_pool->alloc_buff(&addr)) {
    CHECK(false);
    *request = nullptr;
    return ncclSuccess;
  }

  struct UcclRequest* req = reinterpret_cast<struct UcclRequest*>(addr);
  req->type = ReqRxScattered;
  req->n = 1;
  // Using plugin-allocated memory so nccl does not need to manage it.
  req->poll_ctx = ep->uccl_recv_scattered_async(conn_id, req, mhs);
  req->req_pool = (void*)rcomm->base.uccl_req_pool.get();

  *request = req;

  return ncclSuccess;
}

// Release the scattered-receive buffers owned by `request` back to the
// engine, then return the request itself to its pool (pluginTest()
// deliberately skips freeing ReqRxScattered requests).
ncclResult_t pluginIrecvFreePtrs(void* recvComm, void* request) {
  struct UcclRequest* req = reinterpret_cast<struct UcclRequest*>(request);

  struct UcclRecvComm* rcomm = (struct UcclRecvComm*)recvComm;
  auto conn_id = rcomm->base.conn_id;
  ep->uccl_recv_free_ptrs(conn_id, req->iov_n, req->iov_addrs);

  auto uccl_req_pool = reinterpret_cast<UcclRequestBuffPool*>(req->req_pool);
  uccl_req_pool->free_buff(reinterpret_cast<uint64_t>(req));

  return ncclSuccess;
}

// Post an asynchronous flush on the recv comm so received data is visible
// before NCCL consumes it; completion is reported through pluginTest().
ncclResult_t pluginIflush(void* recvComm, int n, void** data, int* sizes,
                          void** mhandles, void** request) {
  struct UcclRecvComm* rcomm = (struct UcclRecvComm*)recvComm;
  auto conn_id = rcomm->base.conn_id;
  struct Mhandle** mhs = (struct Mhandle**)mhandles;
  // checkMemoryLocation(data[0]);
  uint64_t addr;
  auto vdev = rcomm->base.vdev;
  // Pool exhaustion is treated as fatal (CHECK); the nullptr-request
  // fallback below is effectively dead code.
  if (rcomm->base.uccl_req_pool->alloc_buff(&addr)) {
    CHECK(false);
    *request = nullptr;
    return ncclSuccess;
  }

  struct UcclRequest* req = reinterpret_cast<struct UcclRequest*>(addr);
  req->type = ReqFlush;
  req->n = n;
  req->poll_ctx = ep->uccl_flush_async(conn_id, data, req->recv_len, mhs, n);
  req->req_pool = (void*)rcomm->base.uccl_req_pool.get();

  *request = req;

  return ncclSuccess;
}

// Poll a request once for completion. On completion, report sizes according
// to the request type and return the request to its owning pool — except
// scattered receives, which are freed later by pluginIrecvFreePtrs().
ncclResult_t pluginTest(void* request, int* done, int* size) {
  struct UcclRequest* req = reinterpret_cast<struct UcclRequest*>(request);

  if (ep->uccl_poll_once(req->poll_ctx)) {
    *done = 1;
    if (req->type == ReqTx) {
      size[0] = req->send_len;
    } else if (req->type == ReqRx) {
      // Multi-receive: one size per posted buffer.
      for (int i = 0; i < req->n; i++) size[i] = req->recv_len[i];
    } else if (req->type == ReqFlush) {
      // Do nothing.
    } else if (req->type == ReqRxScattered) {
      size[0] = req->recv_len[0];
    }
    // request from ReqRxScattered will be freed by pluginIrecvFreePtrs
    if (req->type != ReqRxScattered) {
      auto uccl_req_pool =
          reinterpret_cast<UcclRequestBuffPool*>(req->req_pool);
      uccl_req_pool->free_buff(reinterpret_cast<uint64_t>(req));
    }

  } else {
    *done = 0;
  }

  return ncclSuccess;
}

// Destroy a send comm created by pluginConnect().
ncclResult_t pluginCloseSend(void* sendComm) {
  struct UcclSendComm* scomm = (struct UcclSendComm*)sendComm;
  if (scomm != nullptr) {
    // The comm holds a shared_ptr to its request pool; run the destructor so
    // the pool's refcount is released instead of leaking via a bare free().
    scomm->~UcclSendComm();
    free(scomm);
  }
  return ncclSuccess;
}
// Destroy a recv comm created by pluginAccept().
ncclResult_t pluginCloseRecv(void* recvComm) {
  struct UcclRecvComm* rcomm = (struct UcclRecvComm*)recvComm;
  if (rcomm != nullptr) {
    // The comm holds a std::string and a shared_ptr request pool; run the
    // destructor so both are released instead of leaking via a bare free().
    rcomm->~UcclRecvComm();
    free(rcomm);
  }
  return ncclSuccess;
}
// Destroy a listen comm created by pluginListen().
ncclResult_t pluginCloseListen(void* listenComm) {
  struct UcclListenComm* comm = (struct UcclListenComm*)listenComm;
  if (comm != nullptr) {
    // The embedded AsyncAcceptState holds a std::string and a shared_ptr;
    // run the destructor so they are released instead of leaking.
    comm->~UcclListenComm();
    free(comm);
  }
  return ncclSuccess;
}

// The v8 net-plugin vtable NCCL dlopens and dispatches through.
// NOTE(review): declared volatile — presumably to keep the symbol from being
// optimized/folded away; confirm against other NCCL plugin implementations.
volatile ncclNet_v8_t ncclNetPlugin_v8 = {
    .name = PLUGIN_NAME,
    .init = pluginInit,
    .devices = pluginDevices,
    .getProperties = pluginGetProperties,
    .listen = pluginListen,
    .connect = pluginConnect,
    .accept = pluginAccept,
    .regMr = pluginRegMr,
    .regMrDmaBuf = pluginRegMrDmaBuf,
    .deregMr = pluginDeregMr,
    .isend = pluginIsend,
    .irecv = pluginIrecv,
    .irecv_scattered = pluginIrecvScattered,
    .irecv_free_ptrs = pluginIrecvFreePtrs,
    .iflush = pluginIflush,
    .test = pluginTest,
    .closeSend = pluginCloseSend,
    .closeRecv = pluginCloseRecv,
    .closeListen = pluginCloseListen,
    // Optional device-side hooks are unimplemented.
    .getDeviceMr = nullptr,
    .irecvConsumed = nullptr,
};
